Perf bwd hip #23
base: performance
Conversation
`split_tbe_bwd.hip.cpp` needed a `.hpp` counterpart to invoke the macros that instantiate all templates for `split_tbe_bwd_hip_kernel_`.
```diff
@@ -8,6 +8,10 @@
 {% set wdesc = "weighted" if weighted else "unweighted" %}
 #include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
 #include "fbgemm_gpu/split_embeddings_utils.cuh"
+#include "hip_kernel/split_tbe_common_hip.h"
```
Use `#ifdef HIP_PLATFORM_HCC` / `#endif` around the new includes.
```diff
@@ -45,6 +45,7 @@
 # An optimization for ROCm
+env.globals["items_per_warp"] = 128 if args.is_rocm is False else 256
```
If we have `use_rocm` exposed to jinja, we can probably avoid needing an extra `items_per_warp` global.
```cpp
static hipFunction_t hip_kernel_func;
{% if optimizer == "rowwise_adagrad" and not dense %}
std::set<int> D_emb_s {64, 128, 192, 256};
bool hip_opt_kernel_supported = (D_emb_s.find(max_D) != D_emb_s.end()) &&
```
When we do mixed dimension this check will go away.
IFU targets upstream commit 5219dc4.
```shell
python test/split_table_batched_embeddings_test.py SplitTableBatchedEmbeddingsTest.test_backward_adagrad_fp32_pmSUM
python test/split_table_batched_embeddings_test.py SplitTableBatchedEmbeddingsTest.test_backward_optimizers_adagrad
```
Can pass the above 2 UTs (with some modifications: the `emb_t` and `grad_t` combination is restricted to `exact` for now, and `weight_decay_mode` is adjusted in the rowwise-adagrad `test_backward_adagrad_fp32_pmSUM`).